import pandas as pd
import numpy as np
import pickle
import time
import scipy

import torchvision.transforms as transforms

import random

import argparse
import os
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision
import torch.nn.functional as F

import torchvision.models as models

from sklearn.neighbors import KNeighborsClassifier

from tqdm import tqdm


# Torch device used for inference in this script (model and batches).
DEVICE = "cpu"


def main():
    """Script entry point: run the feature-collection pipeline."""
    collect_data()


def load_dataloaders(b_size=1, shuffle=False, root='./data/FashionMNIST', num_workers=2):
    """Build the FashionMNIST train and test DataLoaders.

    The dataset is downloaded into *root* if it is not already present.

    Args:
        b_size: Batch size used by both loaders.
        shuffle: Whether the *train* loader shuffles; the test loader never does.
        root: Directory torchvision downloads/reads FashionMNIST from.
        num_workers: Worker processes per DataLoader.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
    # them to [-1, 1]. FashionMNIST is single-channel, so mean/std are 1-tuples.
    # Pass exactly (mean, std): a third positional argument would be consumed
    # as Normalize's ``inplace`` flag, not as another channel.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])

    train_set = torchvision.datasets.FashionMNIST(
        root=root,
        train=True,
        download=True,
        transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=b_size,
        shuffle=shuffle,
        num_workers=num_workers)

    test_set = torchvision.datasets.FashionMNIST(
        root=root,
        train=False,
        download=True,
        transform=transform)

    # The test split is always read in a fixed order.
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=b_size,
        shuffle=False,
        num_workers=num_workers)

    return train_loader, test_loader



class CNN(nn.Module):
    """Five-stage conv net with a global-average-pool + linear head.

    Each stage is conv -> batch norm -> ReLU. Two stages use stride 2, so a
    1x28x28 input yields a 128x7x7 feature map, which AvgPool2d(7) reduces
    to a 128-vector before the 10-way linear classifier.

    ``forward`` returns ``(logits, pooled, feature_map)``:
      * logits:      (batch, 10) class scores,
      * pooled:      (batch, 128) pooled features fed to the linear layer,
      * feature_map: output of the last conv stage (after ReLU).
    """

    # (in_channels, out_channels, kernel_size, stride, padding) per stage.
    # Stage i is registered as ``conv{i}`` / ``bn{i}``; these names are the
    # state_dict keys, so saved checkpoints depend on them.
    _LAYER_SPEC = (
        (1, 8, 5, 1, 2),
        (8, 16, 5, 2, 2),
        (16, 32, 5, 1, 2),
        (32, 64, 5, 2, 2),
        (64, 128, 3, 1, 1),
    )

    def __init__(self):
        super().__init__()
        for stage, (c_in, c_out, k, s, p) in enumerate(self._LAYER_SPEC, start=1):
            setattr(self, f'conv{stage}',
                    nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p))
            setattr(self, f'bn{stage}', nn.BatchNorm2d(c_out))

        self.avgpool = nn.AvgPool2d(7)
        self.linear = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, I):
        feat = I
        for stage in range(1, len(self._LAYER_SPEC) + 1):
            conv = getattr(self, f'conv{stage}')
            norm = getattr(self, f'bn{stage}')
            feat = self.relu(norm(conv(feat)))
        feature_map = feat

        pooled = self.avgpool(feature_map)
        # Drop the (1x1) spatial dims left by the pooling.
        pooled = pooled.view(pooled.shape[0], pooled.shape[1])
        return self.linear(pooled), pooled, feature_map



# (array key, file-name suffix) pairs for the saved .npy outputs, e.g.
# key 'c' of the train split is written to data/X_train_cont.npy.
_OUTPUT_SUFFIXES = (('c', 'cont'), ('C', 'conv'), ('x', 'x'), ('y', 'y'))


def _collect_split(model, loader):
    """Run *model* over *loader* and gather per-image outputs as numpy arrays.

    Returns a dict with keys:
      'c' -- pooled features multiplied element-wise by the linear-layer
             weight row of the predicted class, shape (N, 128);
      'C' -- last conv feature maps with a per-image singleton axis,
             shape (N, 1, 128, H, W) (H = W = 7 for 28x28 inputs); the
             singleton axis is kept for compatibility with existing
             consumers of the X_*_conv.npy files;
      'x' -- pooled features, shape (N, 128);
      'y' -- predicted class indices, shape (N,).
    Feature arrays are float64 and 'y' uses numpy's default int, so the
    on-disk dtypes stay stable for downstream readers.
    """
    weights = model.linear.weight
    cont, conv, feats, preds = [], [], [], []

    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for img, _label in tqdm(loader):
            img = img.to(DEVICE)
            logits, x, C = model(img)
            y_hat = torch.argmax(logits, dim=1)
            c = x * weights[y_hat]

            # Keep numpy arrays rather than Python lists: converting every
            # feature-map value to a boxed float costs many GB on this dataset.
            cont.append(c.cpu().numpy())
            conv.append(C.cpu().numpy())
            feats.append(x.cpu().numpy())
            preds.append(y_hat.cpu().numpy())

    return {
        'c': np.concatenate(cont).astype(np.float64),
        'C': np.concatenate(conv)[:, np.newaxis].astype(np.float64),
        'x': np.concatenate(feats).astype(np.float64),
        'y': np.concatenate(preds).astype(int),
    }


def _save_split(arrays, split):
    """Write every array from _collect_split to data/X_<split>_<suffix>.npy."""
    for key, suffix in _OUTPUT_SUFFIXES:
        np.save(f"data/X_{split}_{suffix}.npy", arrays[key])


def collect_data():
    """Extract CNN features for FashionMNIST and save them under ``data/``.

    Loads weights from ``weights/cnn.pth``, runs the network one image at a
    time over the train and test splits, and saves the outputs as
    ``data/X_{train,test}_{cont,conv,x,y}.npy`` (see _collect_split for what
    each array holds). Finally prints the shapes of the train arrays.
    """
    train_loader, test_loader = load_dataloaders(b_size=1, shuffle=False)

    netC = CNN()
    netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device(DEVICE)))
    # map_location only places the loaded tensors; load_state_dict copies them
    # into the module's existing parameters, so the module must be moved too.
    netC = netC.to(DEVICE).eval()

    os.makedirs('data', exist_ok=True)

    train = _collect_split(netC, train_loader)
    _save_split(train, 'train')
    train_shapes = [train[key].shape for key in ('c', 'x', 'C', 'y')]
    del train  # release the train arrays before processing the test split

    test = _collect_split(netC, test_loader)
    _save_split(test, 'test')
    del test

    for shape in train_shapes:
        print(shape)


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()



